# Copyright (c) InternLM. All rights reserved.
"""
python tools/print_loaded_config.py --file <config_file> [--old_version]
"""

import argparse
import ast
import os
import re
import tempfile
import time
from functools import partial
from typing import Callable, Union

import astor
import jsonlines

from internlm.config import Config, ConfigDict, OldConfig
from internlm.core.context import global_context as gpc
from internlm.data.streaming.packed_dataset import get_packed_dataset

# pylint: disable=W0401, W0614
from internlm.data.streaming.tokenizer_wrappers import *
from internlm.initialize.launch import args_sanity_check


class ConfigConfirmTransformer(ast.NodeTransformer):
    """
    Repair the single config-file statement that made config loading fail.

    The transformer looks for a node located at ``lineno`` and rewrites it:
    comparisons and assignments are patched with a value typed in by the user,
    while ``from internlm.utils.common import read_base`` is redirected to
    ``internlm.config.config`` without asking. Once one statement has been
    repaired ``stop_traversal`` is set, and every node reaching one of the
    visitors below afterwards is returned untouched.

    Args:
        target_line: Source text of the failing line; echoed before prompting.
        lineno: 1-based line number of the statement to repair.
        compare: Whether comparison expressions may be repaired.
        assign: Whether assignment statements may be repaired.
        import_from: Whether ``from ... import`` statements may be repaired.
    """

    def __init__(self, target_line, lineno, compare=True, assign=True, import_from=True):
        super().__init__()
        self.target_line = target_line
        self.lineno = lineno
        self.stop_traversal = False
        self.compare = compare
        self.assign = assign
        self.import_from = import_from

    def visit_Compare(self, node):
        """
        Repair a comparison that raised, e.g. because an environment variable is missing.

        config file: 19     if os.environ['CLUSTER_NAME'] == 'A800':
        interactive output: Enter the new value for 'os.environ['CLUSTER_NAME']' at line 19

        Every non-constant operand among the left side and the first comparator
        is replaced by a constant holding the user's answer.
        """
        if self.stop_traversal or not self.compare:
            return node
        if node.lineno == self.lineno:
            left_value, comparator_value = node.left, node.comparators[0]
            if not isinstance(left_value, ast.Constant):
                left_value = self._ask_for_constant(left_value)
            if not isinstance(comparator_value, ast.Constant):
                # The helper also sets ``stop_traversal``; without that the answer
                # would be thrown away whenever only the right operand is dynamic.
                comparator_value = self._ask_for_constant(comparator_value)
            if self.stop_traversal:
                # Keep the tail of a chained comparison (a < b < c) so the number
                # of comparators still matches the number of operators.
                return ast.Compare(
                    left=left_value,
                    ops=node.ops,
                    comparators=[comparator_value, *node.comparators[1:]],
                )
        return self.generic_visit(node)

    def visit_Assign(self, node):
        """
        Repair an assignment that raised, e.g. because an environment variable is missing.

        config file: 14    JOB_NAME = os.environ["JOB_NAME"]
        interactive output: Enter the new value for 'JOB_NAME' at line 14

        Only assignments whose first target is a plain name are handled; the
        right-hand side is replaced by a constant holding the user's answer.
        """
        if self.stop_traversal or not self.assign:
            return node
        if node.lineno == self.lineno and isinstance(node.targets[0], ast.Name):
            variable_name = node.targets[0].id
            if variable_name:
                new_value = self.get_input_and_replace(
                    f"Enter the new value for '{variable_name}' at line {self.lineno}: "
                )
                self.stop_traversal = True
                return ast.Assign(targets=node.targets, value=ast.Constant(value=new_value))
        return self.generic_visit(node)

    def visit_ImportFrom(self, node):
        """
        Repair a ``from ... import`` that raised because the module no longer exists.

        config file: 14    from internlm.utils.common import read_base
        rewritten automatically to: from internlm.config.config import read_base

        Other names imported on the same line stay on the original statement,
        and the relocated import is emitted right after it.
        """
        if self.stop_traversal or not self.import_from:
            return node
        if node.lineno == self.lineno and node.module == "internlm.utils.common":
            new_names, new_import = [], None
            for alias in node.names:
                if alias.name == "read_base":
                    new_import = ast.ImportFrom(
                        module="internlm.config.config",
                        # Carry over ``as <name>`` so later references keep resolving.
                        names=[ast.alias(name="read_base", asname=alias.asname)],
                        level=0,
                    )
                    ast.fix_missing_locations(new_import)
                    self.stop_traversal = True
                else:
                    new_names.append(alias)

            if not new_names:
                return new_import
            node.names = new_names
            # Returning a list makes NodeTransformer splice both statements in place.
            return [node, new_import] if new_import else node

        return self.generic_visit(node)

    def _ask_for_constant(self, expr):
        """Prompt for a replacement of ``expr`` and return it wrapped in ``ast.Constant``."""
        # astor terminates the generated source with a newline; strip it so the
        # prompt stays on one line.
        var_name = astor.to_source(expr).strip()
        new_value = self.get_input_and_replace(f"Enter the new value for '{var_name}' at line {self.lineno}: ")
        self.stop_traversal = True
        return ast.Constant(value=new_value)

    def get_input_and_replace(self, input_intro):
        """Show the failing line, read a replacement from stdin and return it as a string."""
        print(f"Original line {self.lineno}: {self.target_line.strip()}")
        new_value = input(input_intro)
        print(f"Line {self.lineno} replaced with: {new_value}")
        return new_value


def transform_code(code, target_line, lineno, **kwargs):
    """
    Parse ``code`` and repair the statement found at ``lineno``.

    ``target_line``, ``lineno`` and any extra keyword arguments are handed to
    ``ConfigConfirmTransformer``. The returned tree has missing source
    locations filled in.
    """
    parsed = ast.parse(code)
    patched = ConfigConfirmTransformer(target_line, lineno, **kwargs).visit(parsed)
    # Nodes created by the transformer carry no line/column information yet.
    ast.fix_missing_locations(patched)
    return patched


def split_string_keep_newline(s):
    """
    Cut ``s`` right after every newline, leaving each newline on its piece.

    A string that ends with a newline therefore yields a trailing empty string.
    """
    after_newline = re.compile(r"(?<=\n)")
    return after_newline.split(s)


def clarify_unknown_provider(exception, temp_file, **kwargs):
    """
    Rewrite the line of ``temp_file`` that raised ``exception``.

    The traceback is walked from the second frame onwards. At the first frame
    whose code belongs to ``temp_file`` the file is parsed, the statement at the
    failing line is repaired through ``transform_code`` (which may prompt the
    user) and the regenerated source is written back over ``temp_file``.

    Args:
        exception: The exception raised while loading the config.
        temp_file: Path of the temporary config file being loaded.
        **kwargs: Forwarded to ``transform_code``.

    Returns:
        True if a frame from ``temp_file`` was found and the file was rewritten,
        False otherwise.
    """
    tb = exception.__traceback__
    # An exception that was never raised carries no traceback at all.
    while tb is not None and tb.tb_next:
        tb = tb.tb_next
        if tb.tb_frame.f_code.co_filename != temp_file:
            continue
        code_lineno = tb.tb_lineno

        # The temp config is written as UTF-8 elsewhere in this file, so read and
        # write it the same way instead of relying on the locale encoding.
        with open(temp_file, encoding="utf-8") as f:
            lines = f.readlines()
        line = lines[code_lineno - 1]

        code = transform_code("".join(lines), line, code_lineno, **kwargs)
        # The read handle is closed by now, so the file is not truncated while
        # it is still open for reading.
        with open(temp_file, "w", encoding="utf-8") as file:
            file.writelines(split_string_keep_newline(astor.to_source(code)))
        return True
    return False


def remove_undefined_attributes(exception, temp_file):
    """
    Patch ``temp_file`` when loading failed because ``__file__`` is undefined.

    Only a ``NameError`` whose message names ``__file__`` is handled: every
    ``__file__`` in the file is replaced by ``os.getcwd()``.

    Returns:
        True if the file was patched, False otherwise.
    """
    if isinstance(exception, NameError):
        undefined_names = re.findall(r"name '([^']*)' is not defined", str(exception))
        if "__file__" in undefined_names:
            substitute_base_file_path(temp_file, temp_file, {"__file__": "os.getcwd()"})
            return True
    return False


def substitute_base_file_path(filename: str, temp_config_name: str, base_vars: dict = None) -> None:
    """
    Copy ``filename`` to ``temp_config_name`` while applying text substitutions.

    Args:
        filename: Path of the config file to read (UTF-8).
        temp_config_name: Path the substituted text is written to (UTF-8). It
            may be the same path as ``filename``.
        base_vars: Mapping of regular-expression pattern -> replacement text.
            Must not be None. Replacements are inserted literally.

    Raises:
        AssertionError: If ``base_vars`` is None.
    """
    with open(filename, encoding="utf-8") as f:
        config_file = f.read()
    assert base_vars is not None, "base_vars must map patterns to replacement text"
    for key, value in base_vars.items():
        # A callable replacement is inserted verbatim. Passing ``value`` directly
        # would make ``re.sub`` interpret backslashes (e.g. in Windows paths) as
        # escape sequences, raising ``re.error`` or corrupting the path.
        config_file = re.sub(key, lambda _match, _text=value: _text, config_file)
    with open(temp_config_name, "w", encoding="utf-8") as tmp_config_file:
        tmp_config_file.write(config_file)


# Repairs tried, in order, when loading a new-style config fails. Each one is
# called as ``func(exception, temp_file)`` and returns True once it has patched
# the file. ``import_from=False`` is forwarded to the AST transformer so import
# statements are left alone for new-style configs.
correction_funcs = [
    remove_undefined_attributes,
    partial(clarify_unknown_provider, import_from=False),
]


def load_new_config_with_handling_exception(filename, retries=3, delay=2, delete=False, target_path=None):
    """
    Load a new-style config into ``gpc``, repairing the file and retrying on failure.

    The config is copied into a temporary directory, with every entry returned
    by ``Config._get_base_files`` replaced by the normalised path resolved
    through ``Config._get_cfg_path``. When loading raises, the functions in
    ``correction_funcs`` are tried in order; if one of them patches the
    temporary file the load is retried immediately without consuming an
    attempt. Otherwise one attempt is consumed and the load is retried after
    ``delay`` seconds.

    Parameters:
    filename: path of the config file
    retries: maximum number of attempts
    delay: delay between two attempts (seconds)
    delete: remove the temporary directory before returning
    target_path: if given, the loaded config's ``pretty_text`` is written here

    Raises:
        Exception: the last loading error, once the attempts are used up.
    """
    temp_dir = None
    try:
        filename = os.path.abspath(filename)
        base_map = {}
        for base_cfg_path in Config._get_base_files(filename):
            full_path, _ = Config._get_cfg_path(base_cfg_path, filename)
            base_map[base_cfg_path] = os.path.normpath(full_path)

        # pylint: disable=R1732
        temp_dir = tempfile.TemporaryDirectory()
        temp_file = os.path.join(temp_dir.name, os.path.basename(filename))
        print("writing config in file", temp_file)
        substitute_base_file_path(filename, temp_file, base_map)

        attempt = 0
        while attempt < retries:
            try:
                gpc.load_config(Config.fromfile(temp_file))
                break
            except Exception as e:
                if attempt >= retries - 1:
                    print("Max retries reached. Exiting.")
                    raise
                # A successful repair changes the file, so retry right away
                # without counting it as a failed attempt.
                if any(correction_func(e, temp_file) for correction_func in correction_funcs):
                    continue
                attempt += 1
                print(f"Retrying in {delay} seconds...")
                time.sleep(delay)

        # Dump the config only when loading did not fail: doing this in
        # ``finally`` would run after a failed load too, where it could mask the
        # real error or write a config that was never loaded from this file.
        if target_path:
            with open(target_path, "w", encoding="utf-8") as tmp_config_file:
                tmp_config_file.write(gpc.config.pretty_text)
    finally:
        if delete and temp_dir:
            temp_dir.cleanup()


def get_tokenizer_wrapper(tokenizer_wrapper: Union[str, Callable], multimodal_cfg):
    """
    Resolve a tokenizer-wrapper name to the wrapper class.

    Args:
        tokenizer_wrapper: Either a callable, returned unchanged, or the short
            name of a wrapper.
        multimodal_cfg: When truthy, only the multimodal "pretrain" wrapper is
            available.

    Returns:
        The callable passed in, or the class registered for the given name.

    Raises:
        NotImplementedError: If the name is unknown for the selected mode.
    """
    if callable(tokenizer_wrapper):
        return tokenizer_wrapper

    # Multimodal mode knows a single wrapper.
    if multimodal_cfg:
        if tokenizer_wrapper == "pretrain":
            return StreamingMultimodalPretrainTokenizationWrapper
        raise NotImplementedError(
            f"tokenizer_wrapper {tokenizer_wrapper} for multimodal is not implemented. "
            "Check data.streaming.tokenization_wrapper for more details."
        )

    # Text-only wrappers; each class name is looked up only when its branch is taken.
    if tokenizer_wrapper == "pretrain":
        return StreamingPretrainTokenizationWrapper
    if tokenizer_wrapper == "pretrain_with_loss_mask":
        return StreamingPretrainTokenizationWrapperWithLossMask
    if tokenizer_wrapper == "sft":
        return StreamingSFTTokenizationWrapper
    if tokenizer_wrapper == "sft_multi_round":
        return StreamingSFTMultiRoundTokenizationWrapper
    if tokenizer_wrapper == "s1_multi_round":
        return StreamingS1MultiRoundTokenizationWrapper
    if tokenizer_wrapper == "fim":
        return StreamingFIMTokenizationWrapper
    if tokenizer_wrapper == "online_prompt":
        return StreamingOnlinePromptTokenizationWrapper
    if tokenizer_wrapper == "word_prompt":
        return StreamingWordPromptTokenizationWrapper
    if tokenizer_wrapper == "base_format":
        return StreamingBaseFormatTokenizationWrapper
    if tokenizer_wrapper == "io_format":
        return StreamingIOFormatTokenizationWrapper
    if tokenizer_wrapper == "simple_pretrain":
        return SimpleStreamingTokenizationWrapper
    if tokenizer_wrapper == "flex":
        return StreamingFlexTokenizationWrapper
    if tokenizer_wrapper == "pass_through":
        return StreamingPassThroughTokenizationWrapper

    raise NotImplementedError(
        f"tokenizer_wrapper {tokenizer_wrapper} is not implemented. "
        "Check data.streaming.tokenization_wrapper for more details."
    )


def get_input_and_replace_for_import(lines, lineno, module_name, imported_names, indentation):
    """
    Move ``read_base`` on the import at ``lineno`` to ``internlm.config.config``.

    Both ``lines`` and ``imported_names`` are modified in place. When
    ``read_base`` is not among ``imported_names`` nothing changes.

    Args:
        lines: Source lines of the config file, each ending with a newline.
        lineno: 1-based number of the ``from ... import`` line.
        module_name: Module the remaining names are still imported from.
        imported_names: Names imported on that line.
        indentation: Leading whitespace applied to every generated line.

    Returns:
        The same ``lines`` list.
    """
    if "read_base" not in imported_names:
        return lines

    imported_names.remove("read_base")
    # The relocated import must sit at the same indentation as the statement it
    # replaces; an unindented line after an indented one is inconsistent.
    read_base_import = f"{indentation}from internlm.config.config import read_base\n"
    if imported_names:
        lines[lineno - 1] = f"{indentation}from {module_name} import ({','.join(imported_names)})\n"
        lines.insert(lineno, read_base_import)
    else:
        lines[lineno - 1] = read_base_import
    return lines


def append_python_path(filename, temp_config_name, base_path=None, base_map=None):
    """
    Copy a config file, optionally prepending lines and substituting text.

    Args:
        filename: Path of the file to read (UTF-8).
        temp_config_name: Path the result is written to (UTF-8).
        base_path: Lines placed in front of the file content, each followed by
            a newline. Skipped when empty or None.
        base_map: Mapping of regular-expression pattern -> replacement, applied
            with ``re.sub`` after the lines were prepended. Skipped when empty
            or None.

    Returns:
        ``base_path`` exactly as it was passed in.
    """
    with open(filename, encoding="utf-8") as source:
        text = source.read()

    if base_path:
        text = "\n".join([*base_path, text])

    for pattern, replacement in (base_map or {}).items():
        text = re.sub(pattern, replacement, text)

    with open(temp_config_name, "w", encoding="utf-8") as target:
        target.write(text)
    return base_path


# Repairs tried, in order, when loading an old-style config fails. Each one is
# called as ``func(exception, temp_file)`` and returns True once it has patched
# the file.
old_correction_funcs = [
    clarify_unknown_provider,
]


def load_old_config_with_handling_exception(filename, retries=3, delay=2, delete=False, target_path=None):
    """
    Load an old-style config into ``gpc`` and convert its data section to the new format.

    The config is copied into a temporary directory. Every ``import os`` is
    stripped from the copy and a header is prepended that imports ``sys`` and
    ``os`` and appends the config's directory and the current working directory
    to ``sys.path``. When loading or conversion raises, the functions in
    ``old_correction_funcs`` are tried in order; if one of them patches the
    temporary file the load is retried immediately without consuming an
    attempt. Otherwise one attempt is consumed and the load is retried after
    ``delay`` seconds.

    Parameters:
    filename: path of the config file
    retries: maximum number of attempts
    delay: delay between two attempts (seconds)
    delete: remove the temporary directory before returning
    target_path: if given, the loaded config's ``pretty_text`` is written here

    Raises:
        Exception: the last loading error, once the attempts are used up.
    """
    temp_dir = None
    try:
        filename = os.path.abspath(filename)
        base_path = [
            "import sys",
            "import os",
            "",
            # ``!r`` yields a valid string literal even when the directory
            # contains backslashes or quote characters.
            f"sys.path.append({os.path.dirname(filename)!r})",
            "sys.path.append(os.getcwd())",
        ]

        base_map = {"import os": ""}

        # pylint: disable=R1732
        temp_dir = tempfile.TemporaryDirectory()
        temp_file = os.path.join(temp_dir.name, os.path.basename(filename))

        print("writing config in file", temp_file)
        # First strip the config's own ``import os``, then prepend the header.
        append_python_path(filename, temp_file, base_map=base_map)
        append_python_path(temp_file, temp_file, base_path=base_path)

        attempt = 0
        while attempt < retries:
            try:
                gpc.load_config(OldConfig.fromfile(temp_file))
                gpc.config.data = convert_old_conf_to_new_format(gpc.config.data)
                break
            except Exception as e:
                if attempt >= retries - 1:
                    print("Max retries reached. Exiting.")
                    raise
                # A successful repair changes the file, so retry right away
                # without counting it as a failed attempt.
                if any(correction_func(e, temp_file) for correction_func in old_correction_funcs):
                    continue
                attempt += 1
                print(f"Retrying in {delay} seconds...")
                time.sleep(delay)

        # Dump the config only when loading did not fail: doing this in
        # ``finally`` would run after a failed load too, where it could mask the
        # real error or write a config that was not fully converted.
        if target_path:
            with open(target_path, "w", encoding="utf-8") as tmp_config_file:
                tmp_config_file.write(gpc.config.pretty_text)
    finally:
        if delete and temp_dir:
            temp_dir.cleanup()


def get_import_name(cls_or_funcs):
    """
    Return the dotted import path of a class or function.

    A string is taken to be an import path already and is returned unchanged;
    anything else is rendered as ``<__module__>.<__name__>``.
    """
    if not isinstance(cls_or_funcs, str):
        module, name = cls_or_funcs.__module__, cls_or_funcs.__name__
        return f"{module}.{name}"
    return cls_or_funcs


def get_data_parallel_info(local_rank, world_size, pp_size, sp_size):
    """
    Derive the data-parallel rank and size from the given rank and group sizes.

    ``world_size`` is first divided into ``pp_size`` equally sized groups; each
    group is then divided into chunks of ``sp_size`` ranks. The data-parallel
    size is the number of such chunks per group (at least 1) and the
    data-parallel rank is the index of the chunk holding ``local_rank``.

    Returns:
        A ``(dp_rank, dp_size)`` tuple.
    """
    ranks_per_group = world_size // pp_size
    dp_size = max(1, ranks_per_group // sp_size)
    rank_in_group = local_rank % ranks_per_group
    return rank_in_group // sp_size, dp_size


# Old-style flat option names, grouped by the sub-config that
# ``aggregate_dataset_config`` moves them into. After aggregation every key
# listed in one of the four ``*_keys`` lists is removed from the top level.

# Options copied into ``jsonl_dataset_cfg``.
jsonl_dataset_keys = [
    "debug",
    "data_rank",
    "data_world_size",
    "meta_folder",
    "mmap_meta_stage",
    "attributes",
    "dynamic_attributes_func_str",
    "prev_filter_func_str",
    "lazy_load_attribute",
    "multimodal_params",
]

# Options of ``jsonl_dataset_cfg`` that need converting: old key -> function
# returning the new key/value pairs. ``prev_filter_func`` is stored under
# ``prev_filter_func_str`` as the value returned by ``get_import_name``.
jsonl_dataset_mapping = {
    "prev_filter_func": lambda func: {"prev_filter_func_str": get_import_name(func)},
}

# Options copied into ``tokenizer_wrapper_cfg``.
tokenizer_wrapper_keys = [
    "text_field",
    "prompt_text_field",
    "output_text_field",
    "multimodal_cfg",
    "max_length_per_sample",
    "min_length",
    "tokenizer_cfg",
    "tokenizer_chunk_num",
    "tokenizer_chunk",
    "overlap_length",
    "online_prompt",
    "online_prompt_jsonl_path",
    "fim_conf",
    "always_bos",
    "process_func",
    "keep_extra_info",
    "skip_content_conf",
    "filling_type",
    "loss_token_num",
    "inject_dyname_word_prompt",
    "inject_dyname_token_nums_prompt",
]

# Options copied into ``packed_dataset_cfg``.
packed_dataset_keys = ["packed_length", "drop_last", "return_unpacked", "for_valid", "compute_subset_loss"]

# Options copied into ``weighted_dataset_cfg`` (built only when ``subset`` is false).
weighted_dataset_keys = ["dataset_weights", "seed", "epochs_to_use"]


def remove_keys(data_config, pop_keys):
    """
    Drop every key of ``pop_keys`` that is present in ``data_config``.

    The mapping is modified in place and returned; absent keys are ignored.
    """
    for key in pop_keys:
        if key not in data_config:
            continue
        data_config.pop(key)
    return data_config


def aggregate_dict(data_config, dataset_type=None, aggragate_keys: list = None, aggragate_map: dict = None):
    """
    Build a sub-config from selected entries of ``data_config``.

    Args:
        data_config: Source mapping, left unmodified. With None, only the
            optional ``type`` entry is produced.
        dataset_type: Stored under ``"type"`` when truthy.
        aggragate_keys: Keys copied unchanged when present in ``data_config``.
        aggragate_map: Key -> function. For every key present in
            ``data_config`` the function receives the value and returns a dict
            that is merged into the result.

    Returns:
        A new plain dict.
    """
    dataset_cfg = {"type": dataset_type} if dataset_type else {}
    if data_config is None:
        return dataset_cfg

    if aggragate_keys is not None:
        dataset_cfg.update({key: data_config[key] for key in aggragate_keys if key in data_config})

    if aggragate_map is not None:
        for source_key, convert in aggragate_map.items():
            if source_key in data_config:
                dataset_cfg.update(convert(data_config[source_key]))
    return dataset_cfg


def aggregate_dataset_config(data_cfg, subset=True):
    """
    Regroup the flat options of an old-style dataset config into per-component sub-configs.

    ``data_cfg`` is modified in place. Options are collected into
    ``jsonl_dataset_cfg``, ``aggragation_dataset_cfg``, ``tokenizer_wrapper_cfg``,
    ``packed_dataset_cfg`` and, when ``subset`` is false, ``weighted_dataset_cfg``;
    the flat options are then removed.

    Args:
        data_cfg: Mapping holding the old-style options.
        subset: True for a per-subset config, where component types are only
            filled in if the config asks for them. False for the dataset-level
            config, where every component gets a type.

    Returns:
        A ``ConfigDict`` built from the regrouped ``data_cfg``.
    """
    multimodal_cfg = data_cfg.get("multimodal_cfg", None)
    # The dataset-level config, and any config with a multimodal section, always
    # gets explicit tokenizer-wrapper and packed-dataset types.
    always_typed = not subset or multimodal_cfg is not None

    # --- jsonl dataset ---
    jsonl_type = None
    if multimodal_cfg:
        jsonl_type = "internlm.data.streaming.jsonl_dataset.StreamingMultimodalJsonlDataset"
    elif data_cfg.get("meta_folder", None):
        jsonl_type = "internlm.data.streaming.jsonl_dataset.StreamingAttributeJsonlDataset"
    elif not subset:
        jsonl_type = "internlm.data.streaming.jsonl_dataset.StreamingJsonlDataset"

    jsonl_dataset_cfg = aggregate_dict(data_cfg, jsonl_type, jsonl_dataset_keys, jsonl_dataset_mapping)
    if jsonl_dataset_cfg:
        data_cfg["jsonl_dataset_cfg"] = jsonl_dataset_cfg

    # --- aggregation dataset ---
    if subset:
        if data_cfg.get("debug", False):
            data_cfg["aggragation_dataset_cfg"] = {"debug": True}
    else:
        data_cfg["aggragation_dataset_cfg"] = {
            "type": "internlm.data.streaming.aggregation_dataset.StreamingAggregationDataset"
        }
        if data_cfg.get("debug", False):
            data_cfg["aggragation_dataset_cfg"]["debug"] = True

    # --- tokenizer wrapper ---
    wrapper_import_name = None
    if always_typed or "tokenizer_wrapper" in data_cfg:
        wrapper_import_name = get_import_name(
            get_tokenizer_wrapper(data_cfg.pop("tokenizer_wrapper", "pretrain"), multimodal_cfg)
        )
    tokenizer_wrapper_cfg = aggregate_dict(data_cfg, wrapper_import_name, tokenizer_wrapper_keys)
    if tokenizer_wrapper_cfg:
        data_cfg["tokenizer_wrapper_cfg"] = tokenizer_wrapper_cfg

    # --- packed dataset ---
    packed_import_name = None
    if always_typed or "break_mode" in data_cfg:
        packed_import_name = get_import_name(get_packed_dataset(data_cfg.pop("break_mode", "cut"), multimodal_cfg))
    packed_dataset_cfg = aggregate_dict(data_cfg, packed_import_name, packed_dataset_keys)
    if packed_dataset_cfg:
        data_cfg["packed_dataset_cfg"] = packed_dataset_cfg

    # --- weighted dataset (dataset level only) ---
    if not subset:
        data_cfg["weighted_dataset_cfg"] = aggregate_dict(
            data_cfg,
            "internlm.data.streaming.weighted_dataset.StreamingWeightedDataset",
            weighted_dataset_keys,
        )

    # The flat options now live in the sub-configs; drop them from the top level.
    flat_keys = [*jsonl_dataset_keys, *tokenizer_wrapper_keys, *packed_dataset_keys, *weighted_dataset_keys]
    return ConfigDict(remove_keys(data_cfg, flat_keys))


def convert_old_conf_to_new_format(data_cfg):
    """
    Convert an old-style data config to the new nested layout.

    Only configs whose ``type`` is ``"streaming"`` are converted: the
    dataset-level options are regrouped and each entry of ``subset_params`` is
    regrouped into ``subset_params_cfg``. Any other config is returned as a
    ``ConfigDict`` copy without further changes.
    """
    data_cfg = ConfigDict(data_cfg)
    if data_cfg.type != "streaming":
        return data_cfg

    data_cfg = aggregate_dataset_config(data_cfg, False)
    old_subsets = data_cfg.pop("subset_params") if "subset_params" in data_cfg else {}
    data_cfg["subset_params_cfg"] = {
        subset_name: aggregate_dataset_config(subset_param) for subset_name, subset_param in old_subsets.items()
    }
    return data_cfg


def generate_online_prompt_jsonl(prompt_weights, lang, file_name, mode="w"):
    r"""
    Flatten weighted prompt templates into a jsonl file, one record per prompt.

    ``prompt_weights`` maps a prompt type to a ``{prompt: probability}`` dict.
    Every prompt becomes one record with the keys ``lang``, ``type``, ``prompt``
    and ``prob``.

    Example::

        EN_ONLINE_PROMPT = {
            "aprox_word_num": {
                "The text below comprises roughly {} words.\n": 0.1,
                "Here's a passage that's about {} words long.\n": 0.1,
            },
            "minimum_word_num": {
                "Presented here is a block of text with a minimum of {} words.\n": 0.1,
                "Enclosed is a passage consisting of at least {} words.\n": 0.1,
            },
        }

        # Create the file with the English prompts ...
        generate_online_prompt_jsonl(EN_ONLINE_PROMPT, "en", "data.jsonl", mode="w")
        # ... and append the prompts of another language (e.g. "cn") afterwards.
        generate_online_prompt_jsonl(CN_ONLINE_PROMPT, "cn", "data.jsonl", mode="a")

    Args:
        prompt_weights: Mapping of prompt type -> {prompt: probability}.
        lang: Language tag stored on every record.
        file_name: Path of the jsonl file.
        mode: File mode handed to ``jsonlines.open``.
    """
    records = [
        {"lang": lang, "type": prompt_type, "prompt": prompt, "prob": prob}
        for prompt_type, prompts in prompt_weights.items()
        for prompt, prob in prompts.items()
    ]

    with jsonlines.open(file_name, mode=mode) as writer:
        writer.write_all(records)


def parse_args():
    """Parse the command-line options of the interactive config loader."""
    parser = argparse.ArgumentParser(description="interactive config generation.")
    add = parser.add_argument

    # Input file and retry behaviour.
    add("-f", "--file", required=True, type=str, help="The config file")
    add("-r", "--retries", default=3, type=int, help="Exception retry times")
    add("-d", "--delay", default=2, type=int, help="Delay seconds for retry")

    # Output handling and config flavour.
    add("-t", "--target_path", default=None, type=str, help="Write generated config to tmp file")
    add("-D", "--delete", action="store_true", default=False, help="Delete generated tmp config")
    add("-o", "--old_version", action="store_true", default=False, help="Old version config")

    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    # Old- and new-style configs share the same loader signature.
    load_config = (
        load_old_config_with_handling_exception if args.old_version else load_new_config_with_handling_exception
    )
    load_config(args.file, args.retries, args.delay, args.delete, args.target_path)
    # Validate the loaded config, then print it.
    args_sanity_check()
    print(gpc.config.pretty_text)
